feat(magi): honor AttnMaskSpec on the HF attention backend #2622

Open
HuiyingLi wants to merge 2 commits into main from huiyingl/feat/magi-hf-attn-spec
@HuiyingLi


What

Brings the HF magi attention backend (model.attn_implementation=magi) to parity with the custom-model factory path (model.backend.attn=magi): it now honors an arbitrary AttnMaskSpec (sequence packing / sliding-window / prefix-tree masks), and fails loudly instead of silently dropping a mask.

Why

make_magi_attn_func (custom-model factory) already reads the active AttnMaskSpec and builds the FFA flex key from it. The HF-registered forward (magi_attention_forward, the "magi" entry in ALL_ATTENTION_FUNCTIONS) did not — it only built a plain causal / dispatched key. Consequences:

  • attn_implementation=magi silently dropped any non-causal mask (packing / prefix-tree).
  • A custom model whose attention dispatches on config._attn_implementation (e.g. the registered custom Qwen2) configured with backend.attn=magi falls back to its default attention (flash_attention_2) and drops the mask with no error — the magi forward is never even entered.
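The second failure mode can be pictured with a toy dispatcher. This is a minimal sketch, not the repository's code: `TOY_ATTENTION_FUNCTIONS`, `toy_flash_forward`, `toy_magi_forward`, and `toy_attention` are hypothetical names, and the mask is passed as a plain argument only to keep the example short.

```python
from types import SimpleNamespace


def toy_flash_forward(mask_spec=None):
    # Assumed behaviour for illustration: this path only knows causal masking,
    # so any spec it is handed is ignored without complaint.
    return "causal"


def toy_magi_forward(mask_spec=None):
    # Parity behaviour: use the spec when one is present, else stay causal.
    return "causal" if mask_spec is None else f"flex[{mask_spec}]"


# Hypothetical registry, loosely modelled on a name -> forward lookup table.
TOY_ATTENTION_FUNCTIONS = {
    "flash_attention_2": toy_flash_forward,
    "magi": toy_magi_forward,
}


def toy_attention(config, mask_spec=None):
    # The model picks its forward from config._attn_implementation alone.
    forward = TOY_ATTENTION_FUNCTIONS[config._attn_implementation]
    return forward(mask_spec)


# The backend was switched to magi elsewhere, but the key the model dispatches
# on still holds its default, so the magi forward is never entered.
stale = SimpleNamespace(_attn_implementation="flash_attention_2")
fixed = SimpleNamespace(_attn_implementation="magi")

dropped = toy_attention(stale, mask_spec="prefix_tree")
applied = toy_attention(fixed, mask_spec="prefix_tree")
```

In the `stale` case nothing raises: the call succeeds and simply returns the causal result, which is the silent drop the consumption guard is meant to surface.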

What changed

  • magi_attention_forward reads the per-step AttnMaskSpec stamped on the attention module and builds the flex key from it (cp_size==1). The mask rides on module (already in the HF attention signature) — no process-global is read inside the interface.
  • _set_attn_spec_on_attention(model, spec) stamps the spec on the language-backbone attention modules (sibling of _set_cp_group_on_attention), and arms a consumption guard.
  • Consumption guard: if a spec was armed for a step but no magi forward consumed it (model silently used a non-magi attention), the next step raises a clear error instead of training on the wrong mask.
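The stamping and guard logic above can be sketched roughly as follows. This is an illustration under stated assumptions, not the PR's implementation: only the attribute name `_magi_attn_spec` comes from the description, while `ToyModel`, `ToyAttention`, `_Guard`, the `kind` field, the guard's location on the model, and the choice to check the guard when the next spec is stamped are all hypothetical.

```python
from dataclasses import dataclass


@dataclass
class AttnMaskSpec:
    """Hypothetical stand-in; the real AttnMaskSpec fields are not shown here."""
    kind: str  # e.g. "packing", "sliding_window", "prefix_tree"


class _Guard:
    """Records whether the spec armed for a step was read by a magi forward."""

    def __init__(self):
        self.armed = False
        self.consumed = False


class ToyAttention:
    """Stand-in for an attention module; only attribute storage matters."""


class ToyModel:
    def __init__(self, n_layers):
        self.layers = [ToyAttention() for _ in range(n_layers)]
        self.guard = _Guard()  # assumption: one guard shared by all layers


def set_attn_spec_on_attention(model, spec):
    guard = model.guard
    # Assumed check point: a spec armed last step that nobody consumed means
    # the model ran some non-magi attention and the mask was dropped.
    if guard.armed and not guard.consumed:
        raise RuntimeError("AttnMaskSpec was armed but no magi forward consumed it")
    for layer in model.layers:
        layer._magi_attn_spec = spec  # the mask rides on the module
        layer._magi_guard = guard
    guard.armed = spec is not None
    guard.consumed = False


def toy_magi_forward(module):
    spec = getattr(module, "_magi_attn_spec", None)
    if spec is None:
        return "causal_key"  # unchanged default path when nothing is stamped
    module._magi_guard.consumed = True  # mark the armed spec as consumed
    return f"flex_key[{spec.kind}]"


model = ToyModel(n_layers=2)

# Step 1: spec is stamped and a magi forward reads it.
set_attn_spec_on_attention(model, AttnMaskSpec("prefix_tree"))
step1_key = toy_magi_forward(model.layers[0])

# Step 2: spec is stamped, but no magi forward runs (simulated fallback).
set_attn_spec_on_attention(model, AttnMaskSpec("packing"))

# Step 3: the next stamping call detects the unconsumed spec and raises.
try:
    set_attn_spec_on_attention(model, None)
    guard_fired = False
except RuntimeError:
    guard_fired = True
```

Keeping the spec on the module rather than in a process-global means the forward needs nothing beyond the arguments it already receives, which matches the design choice stated in the first bullet.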

No behavior change when no spec is active (the default): with _magi_attn_spec unset, the forward builds the same plain causal / dispatched key as before.

Enables / relationship to #2564

This is the generic integration piece extracted so that the cp=1 prefix-tree rollout feature (#2564) works through the HF path. With this landed, #2564 stamps its prefix-tree spec via _set_attn_spec_on_attention and uses attn_implementation=magi.

Verification

Scope

Limited to cp_size==1 (prefix-tree / arbitrary masks). The GPU forward is marked # pragma: no cover and is exercised by the FFA parity tests rather than the CPU unit tests.

🤖 Generated with Claude Code

The custom-model magi attn_func reads the active AttnMaskSpec (packing /
sliding-window / prefix-tree masks via the flex key), but the HF-registered
magi forward did not -- so attn_implementation="magi" silently dropped any
non-causal mask while backend.attn="magi" applied it. Worse, a model whose
attention dispatches on config._attn_implementation (e.g. the custom Qwen2)
with backend.attn="magi" falls back to its default attention and drops the
mask with no error.

Bring the HF forward to parity: it now reads the per-step AttnMaskSpec stamped
on the attention module by _set_attn_spec_on_attention() and builds the flex
key from it (cp_size==1; the mask rides on `module`, already in the HF
attention signature, so no process-global is read inside the interface).

Add a consumption guard: a spec armed for a step but never read by a magi
forward raises on the next step, turning the silent non-magi fallback into a
loud error. CPU unit tests cover the stamping + guard; the GPU forward is
exercised by the FFA parity tests.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi HuiyingLi requested a review from a team as a code owner June 17, 2026 12:11
@copy-pr-bot

copy-pr-bot Bot commented Jun 17, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Add CPU tests (magi_attention stubbed) that the registered "magi" HF forward
builds the flex key and marks the spec consumed when _magi_attn_spec is on the
module, and falls back to the dispatched key otherwise.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi


/claude review

@HuiyingLi


/ok to test 81ef112
